{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1097e33-73b2-4001-a76f-782c0cd17644",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "import IPython.display as ipd\n",
    "\n",
    "import os\n",
    "import json\n",
    "import math\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "import commons\n",
    "import utils\n",
    "from data_utils import (\n",
    "    TextAudioLoader,\n",
    "    TextAudioCollate,\n",
    "    TextAudioSpeakerLoader,\n",
    "    TextAudioSpeakerCollate,\n",
    ")\n",
    "from models import SynthesizerTrn\n",
    "from text.symbols import symbols\n",
    "from text import text_to_sequence\n",
    "\n",
    "from scipy.io.wavfile import write\n",
    "\n",
    "\n",
    "def get_text(text, hps):\n",
    "    \"\"\"Convert `text` into the LongTensor of ids passed to the synthesizer.\n",
    "\n",
    "    The text goes through `text_to_sequence` with the cleaners listed in\n",
    "    `hps.data.text_cleaners`. When `hps.data.add_blank` is truthy, the id 0\n",
    "    is interspersed into the sequence via `commons.intersperse`.\n",
    "\n",
    "    Returns:\n",
    "        torch.LongTensor built from the resulting id sequence.\n",
    "    \"\"\"\n",
    "    token_ids = text_to_sequence(text, hps.data.text_cleaners)\n",
    "    if hps.data.add_blank:\n",
    "        token_ids = commons.intersperse(token_ids, 0)\n",
    "    return torch.LongTensor(token_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b653abf-4b03-47d6-80c3-8a4405ba9e56",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Device that later cells move the model and input tensors to via `.to(device)`; fixed to CPU here.\n",
    "device = torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ec35f535-6d34-467c-bf33-aa7c1e673f34",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "# LJSpeech"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb5302d7-7e4b-41c2-a39e-4fec7e46403c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters for the LJSpeech model; the following cells read hps.data, hps.train and hps.model.\n",
    "hps = utils.get_hparams_from_file(\"./configs/vits2_ljs_base.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "074d0c26-5a06-4878-b8da-361f68bd45c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the generator on the notebook-wide `device` (instead of a hard-coded .cuda())\n",
    "# so this section also runs on machines without CUDA, like the VCTK section does.\n",
    "# NOTE(review): the VCTK cell chooses the posterior channel count from\n",
    "# hps.model.use_mel_posterior_encoder; this cell always uses the linear size\n",
    "# (filter_length // 2 + 1) - confirm that matches the LJSpeech checkpoint.\n",
    "net_g = SynthesizerTrn(\n",
    "    len(symbols),\n",
    "    hps.data.filter_length // 2 + 1,\n",
    "    hps.train.segment_size // hps.data.hop_length,\n",
    "    **hps.model).to(device)\n",
    "_ = net_g.eval()\n",
    "\n",
    "_ = utils.load_checkpoint(\"/path/to/pretrained_ljs.pth\", net_g, None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "196962fc-bd97-4b00-9194-e7c0d899e69a",
   "metadata": {},
   "outputs": [],
   "source": [
    "stn_tst = get_text(\"VITS is Awesome!\", hps)\n",
    "with torch.no_grad():\n",
    "    # Inputs must live on the same device as net_g.\n",
    "    x_tst = stn_tst.to(device).unsqueeze(0)\n",
    "    x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)\n",
    "    # Element 0 of infer's return is indexed at [0, 0] and moved to the CPU as a float numpy array for playback.\n",
    "    audio = net_g.infer(x_tst, x_tst_lengths, noise_scale=.667, noise_scale_w=0.8, length_scale=1)[0][0,0].data.cpu().float().numpy()\n",
    "ipd.display(ipd.Audio(audio, rate=hps.data.sampling_rate, normalize=False))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "267ffb0b-d2f0-4e1a-9450-5fc593075810",
   "metadata": {},
   "source": [
    "# VCTK"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51a2ddc8-3cf2-459c-837f-a3fb08261b7d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Rebinds `hps` to the VCTK config, replacing the LJSpeech hyperparameters loaded earlier.\n",
    "hps = utils.get_hparams_from_file(\"./configs/vits2_vctk_base2.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a5d8345-cae7-4137-a700-0474a846112f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Pick the posterior encoder input width that matches the spectrogram type the\n",
    "# config asks for, and record the choice on hps.data for the data loader.\n",
    "# getattr keeps configs that lack the flag working (they take the linear path).\n",
    "if getattr(hps.model, \"use_mel_posterior_encoder\", False):\n",
    "    print(\"Using mel posterior encoder for VITS2\")\n",
    "    posterior_channels = hps.data.n_mel_channels\n",
    "    hps.data.use_mel_posterior_encoder = True\n",
    "else:\n",
    "    print(\"Using lin posterior encoder for VITS1\")\n",
    "    posterior_channels = hps.data.filter_length // 2 + 1\n",
    "    hps.data.use_mel_posterior_encoder = False\n",
    "\n",
    "net_g = SynthesizerTrn(\n",
    "    len(symbols),\n",
    "    posterior_channels,  # was always n_mel_channels, which is wrong on the linear path\n",
    "    None,  # same positional slot the LJSpeech cell fills with segment_size // hop_length\n",
    "    n_speakers=hps.data.n_speakers,\n",
    "    **hps.model,\n",
    ").to(device)\n",
    "_ = net_g.eval()\n",
    "\n",
    "_ = utils.load_checkpoint(\"/path/to/the/pretrained.pth\", net_g, None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "256a6792-da17-4f86-8b67-f7469200a252",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Sentence to synthesize and the speaker id to synthesize it with.\n",
    "text = \"\"\"VITS2 is Awesome!\"\"\"\n",
    "sid = 4\n",
    "\n",
    "# Prepare the model inputs on the inference device.\n",
    "stn_tst = get_text(text, hps)\n",
    "x_tst = stn_tst.unsqueeze(0).to(device)\n",
    "x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)\n",
    "sid = torch.LongTensor([int(sid)]).to(device)\n",
    "\n",
    "with torch.no_grad():\n",
    "    # Only element 0 of infer's return value is needed for playback.\n",
    "    waveform = net_g.infer(\n",
    "        x_tst,\n",
    "        x_tst_lengths,\n",
    "        sid=sid,\n",
    "        noise_scale=0.667,\n",
    "        noise_scale_w=0.8,\n",
    "        length_scale=1,\n",
    "    )[0]\n",
    "\n",
    "# Index [0, 0] (presumably batch 0, channel 0 - TODO confirm), then convert to a float numpy array.\n",
    "audio = waveform[0, 0].data.cpu().float().numpy()\n",
    "ipd.display(ipd.Audio(audio, rate=hps.data.sampling_rate, normalize=False))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7f0753ef-8103-4ef5-b431-3067a86e92b7",
   "metadata": {},
   "source": [
    "# Voice Conversion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92aaa17e-19fa-4cc9-935b-02a6a4f897de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validation utterances listed in the config, served one per batch in file order.\n",
    "dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data)\n",
    "collate_fn = TextAudioSpeakerCollate()\n",
    "loader = DataLoader(\n",
    "    dataset,\n",
    "    batch_size=1,\n",
    "    shuffle=False,\n",
    "    num_workers=0,\n",
    "    pin_memory=False,\n",
    "    drop_last=True,\n",
    "    collate_fn=collate_fn,\n",
    ")\n",
    "# Materializes every batch up front, so this iterates the whole validation set.\n",
    "data_list = list(loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bb98f74-78ed-4ea1-894d-120af96335c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Speaker ids the source utterance is converted to, in display order.\n",
    "TARGET_SPEAKER_IDS = [1, 2, 4]\n",
    "\n",
    "with torch.no_grad():\n",
    "    # Only the first collated validation batch is used; move all of its tensors to the inference device.\n",
    "    x, x_lengths, spec, spec_lengths, y, y_lengths, sid_src = [t.to(device) for t in data_list[0]]\n",
    "    # Run every conversion before displaying anything, collecting (target id, waveform) pairs.\n",
    "    converted = []\n",
    "    for target_id in TARGET_SPEAKER_IDS:\n",
    "        sid_tgt = torch.LongTensor([target_id]).to(device)\n",
    "        result = net_g.voice_conversion(spec, spec_lengths, sid_src=sid_src, sid_tgt=sid_tgt)\n",
    "        converted.append((target_id, result[0][0, 0].data.cpu().float().numpy()))\n",
    "\n",
    "print(\"Original SID: %d\" % sid_src.item())\n",
    "ipd.display(ipd.Audio(y[0].cpu().numpy(), rate=hps.data.sampling_rate, normalize=False))\n",
    "for target_id, converted_audio in converted:\n",
    "    print(\"Converted SID: %d\" % target_id)\n",
    "    ipd.display(ipd.Audio(converted_audio, rate=hps.data.sampling_rate, normalize=False))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
